{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nifti Read Example\n",
    "\n",
    "The purpose of this notebook is to illustrate reading Nifti files and iterating over patches of the volumes loaded from them.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/nifti_read_example.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Install MONAI (with the nibabel extra) only if `import monai` fails in this environment\n",
    "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 1.1.0+11.g7de6c336.dirty\n",
      "Numpy version: 1.22.2\n",
      "Pytorch version: 1.13.0a0+d0d6b1f\n",
      "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
      "MONAI rev id: 7de6c33656a99087ca3b89a817b0879cf093febc\n",
      "MONAI __file__: /workspace/Code/MONAI/monai/__init__.py\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.10\n",
      "Nibabel version: 4.0.2\n",
      "scikit-image version: 0.19.3\n",
      "Pillow version: 9.0.1\n",
      "Tensorboard version: 2.11.0\n",
      "gdown version: 4.6.0\n",
      "TorchVision version: 0.14.0a0\n",
      "tqdm version: 4.64.1\n",
      "lmdb version: 1.3.0\n",
      "psutil version: 5.9.2\n",
      "pandas version: 1.4.4\n",
      "einops version: 0.6.0\n",
      "transformers version: 4.21.3\n",
      "mlflow version: 2.0.1\n",
      "pynrrd version: 1.0.0\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "import os\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from monai.config import print_config\n",
    "from monai.data import ArrayDataset, DataLoader, GridPatchDataset, PatchIter, create_test_image_3d\n",
    "from monai.transforms import (\n",
    "    Compose,\n",
    "    EnsureType,\n",
    "    LoadImage,\n",
    "    RandSpatialCrop,\n",
    "    ScaleIntensity,\n",
    ")\n",
    "from monai.utils import first\n",
    "\n",
    "# Print the MONAI, PyTorch and optional-dependency versions used to run this notebook\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/Data\n"
     ]
    }
   ],
   "source": [
    "# Prefer the user-provided MONAI_DATA_DIRECTORY; otherwise work in a fresh temporary directory\n",
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is None:\n",
    "    root_dir = tempfile.mkdtemp()\n",
    "else:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "    root_dir = directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a number of test Nifti files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(5):\n",
    "    # seed each synthetic volume so the generated dataset is identical on every run\n",
    "    im, seg = create_test_image_3d(128, 128, 128, random_state=np.random.RandomState(i))\n",
    "\n",
    "    # np.eye(4) is used as the affine (identity voxel-to-world mapping)\n",
    "    n = nib.Nifti1Image(im, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"im{i}.nii.gz\"))\n",
    "\n",
    "    n = nib.Nifti1Image(seg, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"seg{i}.nii.gz\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a data loader that yields uniformly random patches from the loaded Nifti files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 1, 64, 64, 64]) torch.Size([5, 1, 64, 64, 64])\n"
     ]
    }
   ],
   "source": [
    "images = sorted(glob.glob(os.path.join(root_dir, \"im*.nii.gz\")))\n",
    "segs = sorted(glob.glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n",
    "# the i-th image is paired with the i-th segmentation, so both lists must line up\n",
    "assert len(images) == len(segs) > 0, \"expected the same, non-zero number of image and segmentation files\"\n",
    "\n",
    "imtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True, ensure_channel_first=True),\n",
    "        ScaleIntensity(),\n",
    "        RandSpatialCrop((64, 64, 64), random_size=False),\n",
    "    ]\n",
    ")\n",
    "\n",
    "segtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True, ensure_channel_first=True),\n",
    "        RandSpatialCrop((64, 64, 64), random_size=False),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# ArrayDataset uses the same random seed for the image and segmentation transforms of an item,\n",
    "# so both RandSpatialCrop calls pick the same location\n",
    "ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "\n",
    "# monai.data.DataLoader re-seeds the dataset's random state in each worker; with the plain torch\n",
    "# DataLoader every worker (and every epoch) may start from the same forked state and repeat crops\n",
    "loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())\n",
    "im, seg = first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, create a data loader that yields patches in regular grid order from the loaded images (note that `GridPatchDataset(..., with_coordinates=False)` is used to ignore any additional output, such as patch coordinates, from the input `patch_iter`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image shapes: torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n"
     ]
    }
   ],
   "source": [
    "# deterministic pipelines: no cropping here, the grid iterator below extracts the patches\n",
    "imtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True, ensure_channel_first=True),\n",
    "        ScaleIntensity(),\n",
    "        EnsureType(),\n",
    "    ]\n",
    ")\n",
    "segtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), EnsureType()])\n",
    "\n",
    "# 64x64x64 patches taken on a regular grid starting at the volume origin\n",
    "patch_iter = PatchIter(patch_size=(64, 64, 64), start_pos=(0, 0, 0))\n",
    "\n",
    "\n",
    "def img_seg_iter(x):\n",
    "    \"\"\"Walk image `x[0]` and segmentation `x[1]` in lockstep, yielding one (image, seg) patch pair at a time.\"\"\"\n",
    "    image_patches = patch_iter(x[0])\n",
    "    seg_patches = patch_iter(x[1])\n",
    "    for img_patch, seg_patch in zip(image_patches, seg_patches):\n",
    "        # element [0] is the patch and [1] its coordinates; uncomment to confirm the coordinates match\n",
    "        # print(\"coord img:\", img_patch[1].flatten(), \"coord seg:\", seg_patch[1].flatten())\n",
    "        yield ((img_patch[0], seg_patch[0]),)\n",
    "\n",
    "\n",
    "ds = GridPatchDataset(ArrayDataset(images, imtrans, segs, segtrans), img_seg_iter, with_coordinates=False)\n",
    "\n",
    "loader = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=0, pin_memory=torch.cuda.is_available())\n",
    "im, seg = first(loader)\n",
    "print(\"image shapes:\", im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove the directory if a temporary one was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete only the temporary directory created when MONAI_DATA_DIRECTORY is unset;\n",
    "# files written into a user-provided MONAI_DATA_DIRECTORY are kept\n",
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
